import argparse
import json
import os

from run import run_sac
from UnitCell_Environment.unitcell_environment.env.utils import load_compositions

def _parse_args(argv=None):
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Optimize starting structures with an existing SAC policy")
    parser.add_argument("--algo_config", type=str, default="config.json",
                        help="The config file for tunable parameters of the chosen RL method")
    parser.add_argument("--env_config", type=str, default="env_config.json",
                        help="The config file for the env parameters")
    parser.add_argument("--policy", type=str, default=os.path.abspath("policy"),
                        help="the folder containing the policy to be used")
    parser.add_argument("--comp_config", type=str, default="comp_config.json",
                        help="the config file with compositions and structure sizes supported")
    parser.add_argument("--comp", type=str, default="SrTiO3x8",
                        help="specific composition/structure size to optimize")
    # NOTE(review): this default is tied to the default --comp; when passing a
    # different --comp, remember to pass the matching --input_dir as well.
    parser.add_argument("--input_dir", type=str, default="starting_structures/SrTiO3x8",
                        help="input dir with starting structures")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="The output dir to save the optimized structures and the energy history")
    return parser.parse_args(argv)


def _load_json(path):
    """Read and return the JSON document stored at *path*."""
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def main(argv=None):
    """Load the algorithm, environment and composition configs, then run SAC.

    The environment config is overridden with values derived from the
    algorithm config and the command line before being handed to ``run_sac``
    together with the policy folder and the starting-structure directory.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Whatever ``run_sac`` returns.
    """
    args = _parse_args(argv)

    algo_config = _load_json(args.algo_config)
    # NOTE(review): forced to 0 here, presumably so this run uses no extra
    # rollout workers — confirm against run_sac.
    algo_config["num_workers"] = 0

    env_config = _load_json(args.env_config)
    # Keep the env's discount factor and GPU flag in sync with the algorithm config.
    env_config["gamma"] = algo_config.get("gamma", 0.99)
    env_config["gpu"] = algo_config.get("num_gpus", 0) > 0
    # NOTE(review): meaning of this agent level tag is defined by the env — confirm.
    env_config["agent_levels"] = ["2_stepsize_"]

    if args.output_dir is not None:
        env_config["output_dir"] = args.output_dir
    if args.comp is not None:
        env_config["comp"] = args.comp
    env_config["store_trajectory"] = False

    load_compositions(_load_json(args.comp_config))

    return run_sac(env_config,
                   algo_config,
                   input_dir=args.input_dir,
                   checkpoint=args.policy)


if __name__ == "__main__":
    main()
